{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp callback.tensorboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#all_slow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.basics import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tensorboard\n",
    "\n",
    "> Integration with [tensorboard](https://www.tensorflow.org/tensorboard) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, install tensorboard with\n",
    "```\n",
    "pip install tensorboard\n",
    "```\n",
    "Then launch tensorboard with\n",
    "```\n",
    "tensorboard --logdir=runs\n",
    "```\n",
    "in your terminal. You can change the logdir as long as it matches the `log_dir` you pass to `TensorBoardCallback` (default is `runs` in the working directory)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tensorboard Embedding Projector support"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Tensorboard Embedding Projector is currently only supported for image classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Export Image Features during Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tensorboard [Embedding Projector](https://www.tensorflow.org/tensorboard/tensorboard_projector_plugin) is supported in `TensorBoardCallback` (set parameter `projector=True`) during training. The validation set embeddings will be written after each epoch.\n",
    "\n",
    "```\n",
    "cbs = [TensorBoardCallback(projector=True)]\n",
    "learn = cnn_learner(dls, resnet18, metrics=accuracy)\n",
    "learn.fit_one_cycle(3, cbs=cbs)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Export Image Features during Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To write the embeddings for a custom dataset (e.g. after loading a learner), use `TensorBoardProjectorCallback`. Add the callback manually to the learner.\n",
    "\n",
    "```\n",
    "learn = load_learner('path/to/export.pkl')\n",
    "learn.add_cb(TensorBoardProjectorCallback())\n",
    "dl = learn.dls.test_dl(files, with_labels=True)\n",
    "_ = learn.get_preds(dl=dl)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you use a custom model (anything other than a fastai ResNet), pass the layer from which the embeddings should be extracted via the callback's `layer` parameter.\n",
    "\n",
    "```\n",
    "layer = learn.model[1][1]\n",
    "cbs = [TensorBoardProjectorCallback(layer=layer)]\n",
    "preds = learn.get_preds(dl=dl, cbs=cbs)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Export Word Embeddings from Language Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also export word embeddings from language models. This was tested with AWD_LSTM (fast.ai) and GPT2 / BERT (transformers), but it works with every model that contains an embedding layer."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a **fast.ai TextLearner or LMLearner** just pass the learner - the embedding layer and vocab will be extracted automatically:\n",
    "```\n",
    "dls = TextDataLoaders.from_folder(untar_data(URLs.IMDB), valid='test')\n",
    "learn = text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy)\n",
    "projector_word_embeddings(learn=learn, limit=2000, start=2000)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For other language models - like the ones in the **transformers library** - you'll have to pass the layer and vocab. Here's an example for a **BERT** model.\n",
    "```\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "model = AutoModel.from_pretrained(\"bert-base-uncased\")\n",
    "\n",
    "# get the word embedding layer\n",
    "layer = model.embeddings.word_embeddings\n",
    "\n",
    "# get and sort vocab\n",
    "vocab_dict = tokenizer.get_vocab()\n",
    "vocab = [k for k, v in sorted(vocab_dict.items(), key=lambda x: x[1])]\n",
    "\n",
    "# write the embeddings for tb projector\n",
    "projector_word_embeddings(layer=layer, vocab=vocab, limit=2000, start=2000)\n",
    "```"
   ]
  },
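  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The vocab is sorted by token id because `projector_word_embeddings` applies the same slice to the embedding rows and to `vocab`, so entry `i` of the list must be the token with id `i`. A minimal sketch of that step, using a made-up three-token dictionary (`toy_vocab_dict` is hypothetical data, not a real tokenizer vocabulary):\n",
    "\n",
    "```python\n",
    "# Hypothetical token -> id mapping, deliberately listed out of order.\n",
    "toy_vocab_dict = {'world': 2, '[PAD]': 0, 'hello': 1}\n",
    "\n",
    "# Sort by id so that position i in the list holds the token with id i.\n",
    "toy_vocab = [tok for tok, idx in sorted(toy_vocab_dict.items(), key=lambda kv: kv[1])]\n",
    "print(toy_vocab)\n",
    "```\n",
    "\n",
    "A real tokenizer vocabulary is handled the same way, only with many more entries."
   ]
  },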
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "import tensorboard\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "from fastai.callback.fp16 import ModelToHalf\n",
    "from fastai.callback.hook import hook_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class TensorBoardBaseCallback(Callback):\n",
    "    \"Base class for tensorboard callbacks\"\n",
    "    def __init__(self):\n",
    "        self.run_projector = False\n",
    "        \n",
    "    def after_pred(self):\n",
    "        if self.run_projector: self.feat = _add_projector_features(self.learn, self.h, self.feat)\n",
    "    \n",
    "    def after_validate(self):\n",
    "        if not self.run_projector: return\n",
    "        self.run_projector = False\n",
    "        self._remove()\n",
    "        _write_projector_embedding(self.learn, self.writer, self.feat)\n",
    "            \n",
    "    def after_fit(self): \n",
    "        if self.run: self.writer.close()\n",
    "        \n",
    "    def _setup_projector(self):\n",
    "        self.run_projector = True\n",
    "        self.h = hook_output(self.learn.model[1][1] if not self.layer else self.layer)\n",
    "        self.feat = {}\n",
    "        \n",
    "    def _setup_writer(self):\n",
    "        self.writer = SummaryWriter(log_dir=self.log_dir)\n",
    "    \n",
    "    def _remove(self):\n",
    "        if getattr(self, 'h', None): self.h.remove()\n",
    "\n",
    "    def __del__(self): self._remove()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class TensorBoardCallback(TensorBoardBaseCallback):\n",
    "    \"Saves model topology, losses & metrics for tensorboard and tensorboard projector during training\"\n",
    "    def __init__(self, log_dir=None, trace_model=True, log_preds=True, n_preds=9, projector=False, layer=None):\n",
    "        super().__init__()\n",
    "        store_attr()\n",
    "\n",
    "    def before_fit(self):\n",
    "        self.run = not hasattr(self.learn, 'lr_finder') and not hasattr(self, \"gather_preds\") and rank_distrib()==0\n",
    "        if not self.run: return\n",
    "        self._setup_writer()\n",
    "        if self.trace_model:\n",
    "            if hasattr(self.learn, 'mixed_precision'):\n",
    "                raise Exception(\"Can't trace model in mixed precision, pass `trace_model=False` or don't use FP16.\")\n",
    "            b = self.dls.one_batch()\n",
    "            self.learn._split(b)\n",
    "            self.writer.add_graph(self.model, *self.xb)\n",
    "\n",
    "    def after_batch(self):\n",
    "        self.writer.add_scalar('train_loss', self.smooth_loss, self.train_iter)\n",
    "        for i,h in enumerate(self.opt.hypers):\n",
    "            for k,v in h.items(): self.writer.add_scalar(f'{k}_{i}', v, self.train_iter)\n",
    "\n",
    "    def after_epoch(self):\n",
    "        for n,v in zip(self.recorder.metric_names[2:-1], self.recorder.log[2:-1]):\n",
    "            self.writer.add_scalar(n, v, self.train_iter)\n",
    "        if self.log_preds:\n",
    "            b = self.dls.valid.one_batch()\n",
    "            self.learn.one_batch(0, b)\n",
    "            preds = getattr(self.loss_func, 'activation', noop)(self.pred)\n",
    "            out = getattr(self.loss_func, 'decodes', noop)(preds)\n",
    "            x,y,its,outs = self.dls.valid.show_results(b, out, show=False, max_n=self.n_preds)\n",
    "            tensorboard_log(x, y, its, outs, self.writer, self.train_iter)\n",
    "            \n",
    "    def before_validate(self):\n",
    "        if self.projector: self._setup_projector()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class TensorBoardProjectorCallback(TensorBoardBaseCallback):\n",
    "    \"Extracts and exports image features for tensorboard projector during inference\"\n",
    "    def __init__(self, log_dir=None, layer=None):\n",
    "        super().__init__()\n",
    "        store_attr()\n",
    "    \n",
    "    def before_fit(self):\n",
    "        self.run = not hasattr(self.learn, 'lr_finder') and hasattr(self, \"gather_preds\") and rank_distrib()==0\n",
    "        if not self.run: return\n",
    "        self._setup_writer()\n",
    "\n",
    "    def before_validate(self):\n",
    "        self._setup_projector()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _write_projector_embedding(learn, writer, feat):\n",
    "    lbls = [learn.dl.vocab[l] for l in feat['lbl']] if getattr(learn.dl, 'vocab', None) else None     \n",
    "    vecs = feat['vec'].squeeze()\n",
    "    writer.add_embedding(vecs, metadata=lbls, label_img=feat['img'], global_step=learn.train_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _add_projector_features(learn, hook, feat):\n",
    "    img = _normalize_for_projector(learn.x)\n",
    "    first_epoch = learn.iter == 0\n",
    "    feat['vec'] = hook.stored if first_epoch else torch.cat((feat['vec'], hook.stored),0)\n",
    "    feat['img'] = img           if first_epoch else torch.cat((feat['img'], img),0)\n",
    "    if getattr(learn.dl, 'vocab', None):\n",
    "        feat['lbl'] = learn.y if first_epoch else torch.cat((feat['lbl'], learn.y),0)\n",
    "    return feat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _get_embeddings(model, layer):\n",
    "    layer = model[0].encoder if layer is None else layer\n",
    "    return layer.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def _normalize_for_projector(x:TensorImage):\n",
    "    # normalize tensor to be between 0-1\n",
    "    img = x.clone()\n",
    "    sz = img.shape\n",
    "    img = img.view(x.size(0), -1)\n",
    "    img -= img.min(1, keepdim=True)[0]\n",
    "    img /= img.max(1, keepdim=True)[0]\n",
    "    img = img.view(*sz)\n",
    "    return img"
   ]
  },
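  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function above rescales every image in the batch independently, so that its smallest value becomes 0 and its largest becomes 1. A dependency-free sketch of the same per-row min-max rule on nested lists (`min_max_rows` is a hypothetical helper for illustration, and it assumes each row contains at least two distinct values):\n",
    "\n",
    "```python\n",
    "def min_max_rows(batch):\n",
    "    # batch: list of flattened images, one list of floats per image\n",
    "    out = []\n",
    "    for row in batch:\n",
    "        lo, hi = min(row), max(row)\n",
    "        # Shift so the minimum is 0, then scale so the maximum is 1.\n",
    "        out.append([(v - lo) / (hi - lo) for v in row])\n",
    "    return out\n",
    "\n",
    "scaled = min_max_rows([[2.0, 4.0, 6.0], [0.0, 10.0, 5.0]])\n",
    "print(scaled)\n",
    "```\n",
    "\n",
    "Each row is scaled with its own minimum and maximum, so a bright image and a dark image both end up spanning the full 0 to 1 range."
   ]
  },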
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "from fastai.text.all import LMLearner, TextLearner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def projector_word_embeddings(learn=None, layer=None, vocab=None, limit=-1, start=0, log_dir=None):\n",
    "    \"Extracts and exports word embeddings from language models embedding layers\"\n",
    "    if not layer:\n",
    "        if   isinstance(learn, LMLearner):   layer = learn.model[0].encoder\n",
    "        elif isinstance(learn, TextLearner): layer = learn.model[0].module.encoder\n",
    "    emb = layer.weight\n",
    "    img = torch.full((len(emb),3,8,8), 0.7)\n",
    "    vocab = learn.dls.vocab[0] if vocab is None else vocab\n",
    "    vocab = list(map(lambda x: f'{x}_', vocab))\n",
    "    writer = SummaryWriter(log_dir=log_dir)\n",
    "    end = start + limit if limit >= 0 else None  # a negative limit keeps every row from `start` on\n",
    "    writer.add_embedding(emb[start:end], metadata=vocab[start:end], label_img=img[start:end])\n",
    "    writer.close()"
   ]
  },
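  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`start` and `limit` select a contiguous window of embedding rows. The sketch below mirrors that slicing for a non-negative `limit` on a plain list; `rows` and `select_rows` are hypothetical names used only for illustration:\n",
    "\n",
    "```python\n",
    "# Stand-in for the rows of an embedding matrix (hypothetical data).\n",
    "rows = list(range(10))\n",
    "\n",
    "def select_rows(rows, start=0, limit=3):\n",
    "    # Same slice as used for a non-negative limit: rows[start:start+limit].\n",
    "    return rows[start:start + limit]\n",
    "\n",
    "picked = select_rows(rows, start=4, limit=3)\n",
    "print(picked)\n",
    "```\n",
    "\n",
    "By the same rule, `start=2000, limit=2000` selects rows 2000 through 3999, provided the vocabulary is at least that large."
   ]
  },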
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.vision.data import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def tensorboard_log(x:TensorImage, y: TensorCategory, samples, outs, writer, step):\n",
    "    fig,axs = get_grid(len(samples), add_vert=1, return_fig=True)\n",
    "    for i in range(2):\n",
    "        axs = [b.show(ctx=c) for b,c in zip(samples.itemgot(i),axs)]\n",
    "    axs = [r.show(ctx=c, color='green' if b==r else 'red')\n",
    "            for b,r,c in zip(samples.itemgot(1),outs.itemgot(0),axs)]\n",
    "    writer.add_figure('Sample results', fig, step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.vision.core import TensorPoint,TensorBBox"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def tensorboard_log(x:TensorImage, y: (TensorImageBase, TensorPoint, TensorBBox), samples, outs, writer, step):\n",
    "    fig,axs = get_grid(len(samples), add_vert=1, return_fig=True, double=True)\n",
    "    for i in range(2):\n",
    "        axs[::2] = [b.show(ctx=c) for b,c in zip(samples.itemgot(i),axs[::2])]\n",
    "    for x in [samples,outs]:\n",
    "        axs[1::2] = [b.show(ctx=c) for b,c in zip(x.itemgot(0),axs[1::2])]\n",
    "    writer.add_figure('Sample results', fig, step)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.vision.all import Resize, RandomSubsetSplitter, aug_transforms, cnn_learner, resnet18"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TensorBoardCallback"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.PETS)\n",
    "\n",
    "db = DataBlock(blocks=(ImageBlock, CategoryBlock), \n",
    "                  get_items=get_image_files, \n",
    "                  item_tfms=Resize(128),\n",
    "                  splitter=RandomSubsetSplitter(train_sz=0.1, valid_sz=0.01),\n",
    "                  batch_tfms=aug_transforms(size=64),\n",
    "                  get_y=using_attr(RegexLabeller(r'(.+)_\\d+.*$'), 'name'))\n",
    "\n",
    "dls = db.dataloaders(path/'images')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = cnn_learner(dls, resnet18, metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>5.099678</td>\n",
       "      <td>6.461753</td>\n",
       "      <td>0.054795</td>\n",
       "      <td>00:17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>4.268377</td>\n",
       "      <td>5.123801</td>\n",
       "      <td>0.095890</td>\n",
       "      <td>00:15</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>3.713221</td>\n",
       "      <td>3.094202</td>\n",
       "      <td>0.178082</td>\n",
       "      <td>00:14</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.unfreeze()\n",
    "learn.fit_one_cycle(3, cbs=TensorBoardCallback(Path.home()/'tmp'/'runs'/'tb', trace_model=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Projector"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Projector in TensorBoardCallback"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.PETS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "db = DataBlock(blocks=(ImageBlock, CategoryBlock), \n",
    "                  get_items=get_image_files, \n",
    "                  item_tfms=Resize(128),\n",
    "                  splitter=RandomSubsetSplitter(train_sz=0.05, valid_sz=0.01),\n",
    "                  batch_tfms=aug_transforms(size=64),\n",
    "                  get_y=using_attr(RegexLabeller(r'(.+)_\\d+.*$'), 'name'))\n",
    "\n",
    "dls = db.dataloaders(path/'images')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cbs = [TensorBoardCallback(log_dir=Path.home()/'tmp'/'runs'/'vision1', projector=True)]\n",
    "learn = cnn_learner(dls, resnet18, metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>5.187016</td>\n",
       "      <td>8.598857</td>\n",
       "      <td>0.054795</td>\n",
       "      <td>00:07</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>4.667810</td>\n",
       "      <td>6.271108</td>\n",
       "      <td>0.136986</td>\n",
       "      <td>00:07</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>4.169415</td>\n",
       "      <td>4.457378</td>\n",
       "      <td>0.136986</td>\n",
       "      <td>00:08</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.unfreeze()\n",
    "learn.fit_one_cycle(3, cbs=cbs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TensorBoardProjectorCallback"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.PETS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "db = DataBlock(blocks=(ImageBlock, CategoryBlock), \n",
    "                  get_items=get_image_files, \n",
    "                  item_tfms=Resize(128),\n",
    "                  splitter=RandomSubsetSplitter(train_sz=0.1, valid_sz=0.01),\n",
    "                  batch_tfms=aug_transforms(size=64),\n",
    "                  get_y=using_attr(RegexLabeller(r'(.+)_\\d+.*$'), 'name'))\n",
    "\n",
    "dls = db.dataloaders(path/'images')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "files = get_image_files(path/'images')\n",
    "files = files[:256]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dl = learn.dls.test_dl(files, with_labels=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = cnn_learner(dls, resnet18, metrics=accuracy)\n",
    "layer = learn.model[1][0].ap\n",
    "cbs = [TensorBoardProjectorCallback(layer=layer, log_dir=Path.home()/'tmp'/'runs'/'vision2')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "_ = learn.get_preds(dl=dl, cbs=cbs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## projector_word_embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### fastai text or lm learner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.text.all import TextDataLoaders, text_classifier_learner, AWD_LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = TextDataLoaders.from_folder(untar_data(URLs.IMDB), valid='test')\n",
    "learn = text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "projector_word_embeddings(learn, limit=1000, log_dir=Path.home()/'tmp'/'runs'/'text')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### transformers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### GPT2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#from transformers import GPT2LMHeadModel, GPT2TokenizerFast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')\n",
    "#model = GPT2LMHeadModel.from_pretrained('gpt2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#layer = model.transformer.wte"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#vocab_dict = tokenizer.get_vocab()\n",
    "#vocab = [k for k, v in sorted(vocab_dict.items(), key=lambda x: x[1])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#projector_word_embeddings(layer=layer, vocab=vocab, limit=2000, log_dir=Path.home()/'tmp'/'runs'/'transformers')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### BERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#from transformers import AutoTokenizer, AutoModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "#model = AutoModel.from_pretrained(\"bert-base-uncased\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#layer = model.embeddings.word_embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#vocab_dict = tokenizer.get_vocab()\n",
    "#vocab = [k for k, v in sorted(vocab_dict.items(), key=lambda x: x[1])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#projector_word_embeddings(layer=layer, vocab=vocab, limit=2000, start=2000, log_dir=Path.home()/'tmp'/'runs'/'transformers')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Validate results in tensorboard"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the following command in the command line to check that the projector embeddings have been written correctly:\n",
    "\n",
    "```\n",
    "tensorboard --logdir=~/tmp/runs\n",
    "```\n",
    "\n",
    "Open http://localhost:6006 in your browser (TensorBoard Projector doesn't work correctly in Safari!)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.export import *\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "fastaidev",
   "language": "python",
   "name": "fastaidev"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
